{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "# !pip install wget\n",
    "# !pip install git+https://github.com/NVIDIA/apex.git\n",
    "# !pip install nemo_toolkit[nlp]\n",
    "# !pip install unidecode\n",
    "import os\n",
    "import nemo\n",
    "import nemo.collections.nlp as nemo_nlp\n",
    "import numpy as np\n",
    "import time\n",
    "import errno\n",
    "\n",
    "from nemo.backends.pytorch.common.losses import CrossEntropyLossNM\n",
    "from nemo.collections.nlp.data.datasets import TextClassificationDataDesc\n",
    "from nemo.collections.nlp.nm.data_layers import BertTextClassificationDataLayer\n",
    "from nemo.collections.nlp.nm.trainables import SequenceClassifier\n",
    "from nemo.collections.nlp.callbacks.text_classification_callback import eval_epochs_done_callback, eval_iter_callback\n",
    "from nemo.utils.lr_policies import get_lr_policy\n",
    "from nemo import logging"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction\n",
    "BioBERT has the same network architecture as the original BERT, but instead of Wikipedia and BookCorpus it is pretrained on PubMed, a large biomedical text corpus. This achieves better performance on biomedical downstream tasks, such as question answering (QA), named entity recognition (NER) and relation extraction (RE). This model was trained for 1M steps. For more information please refer to the original paper https://academic.oup.com/bioinformatics/article/36/4/1234/5566506. For details about BERT please refer to https://ngc.nvidia.com/catalog/models/nvidia:bertbaseuncasedfornemo.\n",
    "\n",
    "\n",
    "In this notebook we're going to showcase how to train BioBERT on a biomedical relation extraction (RE) dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download model checkpoint\n",
    "Download BioBERT/BioMegatron checkpoints from NGC: https://ngc.nvidia.com/catalog/models and put the encoder weights \n",
    "at `./checkpoints/biobert/BERT.pt` or `./checkpoints/biomegatron/BERT.pt` and the model configuration file at `./checkpoints/biobert/bert_config.json` or `./checkpoints/biomegatron/bert_config.json`. \n",
    "For BioBERT download e.g. https://ngc.nvidia.com/catalog/models/nvidia:biobertbasecasedfornemo.\n",
    "For BioMegatron download e.g. https://ngc.nvidia.com/catalog/models/nvidia:biomegatron345muncased."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the pretrained encoder to fine-tune: \"biobert\" or \"biomegatron\".\n",
    "model_type = \"biobert\"\n",
    "\n",
    "# Per-model settings; each dict is indexed with `model_type` in later cells.\n",
    "\n",
    "# Directory expected to hold the downloaded encoder weights and configuration file.\n",
    "base_checkpoint_path = {\n",
    "    'biobert': './checkpoints/biobert/',\n",
    "    'biomegatron': './checkpoints/biomegatron/',\n",
    "}\n",
    "# Value passed as `pretrained_model_name` when building the model and the tokenizer.\n",
    "pretrained_model_name = {\n",
    "    'biobert': 'bert-base-cased',\n",
    "    'biomegatron': 'megatron-bert-uncased',\n",
    "}\n",
    "# Value passed as `do_lower_case` to the tokenizer.\n",
    "do_lower_case = {\n",
    "    'biobert': False,\n",
    "    'biomegatron': True,\n",
    "}\n",
    "# Working directory name for each model.\n",
    "work_dir = {\n",
    "    'biobert': 'output_re_biobert',\n",
    "    'biomegatron': 'output_re_biomegatron',\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The checkpoints are available from NGC: https://ngc.nvidia.com/catalog/models\n",
    "checkpoint_dir = base_checkpoint_path[model_type]\n",
    "CHECKPOINT_ENCODER = os.path.join(checkpoint_dir, 'BERT.pt')  # model encoder checkpoint file\n",
    "CHECKPOINT_CONFIG = os.path.join(checkpoint_dir, 'bert_config.json')  # model configuration file\n",
    "\n",
    "# Stop right here with an ENOENT error naming the first file that is missing,\n",
    "# checking the encoder weights before the configuration file.\n",
    "for required_file in (CHECKPOINT_ENCODER, CHECKPOINT_CONFIG):\n",
    "    if not os.path.exists(required_file):\n",
    "        raise OSError(errno.ENOENT, os.strerror(errno.ENOENT), required_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download training data\n",
    "In this example we download the RE dataset chemprot to ./datasets/chemprot and process it with text_classification/data/import_datasets.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download https://github.com/arwhirang/recursive_chemprot/blob/master/Demo/tree_LSTM/data/chemprot-data_treeLSTM.zip\n",
    "# and extract it into ./datasets/chemprot. The download is skipped when that directory already exists.\n",
    "data_dir=\"./datasets\"\n",
    "dataset=\"chemprot\"\n",
    "if not os.path.exists(f\"{data_dir}/{dataset}\"):\n",
    "    # Create the target directory only after the download succeeded. Creating it first would\n",
    "    # leave an empty directory behind when the download fails, and the check above would then\n",
    "    # skip the download on every later run.\n",
    "    !wget \"https://github.com/arwhirang/recursive_chemprot/blob/master/Demo/tree_LSTM/data/chemprot-data_treeLSTM.zip?raw=true\" -O data.zip && mkdir -p $data_dir/$dataset && unzip data.zip -d $data_dir/$dataset\n",
    "    # Always remove the archive, including the partial file a failed download leaves behind.\n",
    "    !rm -f data.zip\n",
    "\n",
    "# Process the extracted files with the import script, reading from and writing to the dataset directory.\n",
    "!python ../text_classification/data/import_datasets.py --source_data_dir=$data_dir/$dataset --target_data_dir=$data_dir/$dataset --dataset_name=$dataset\n",
    "!ls -l $data_dir/$dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the previous step, you should have a ./datasets/chemprot folder that contains the following files:\n",
    "- train.tsv\n",
    "- test.tsv\n",
    "- dev.tsv\n",
    "- label_mapping.tsv\n",
    "\n",
    "The format of the data is described in the NeMo docs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Neural Modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_checkpoint = CHECKPOINT_ENCODER  # language model encoder file\n",
    "model_config = CHECKPOINT_CONFIG  # model configuration file\n",
    "# Bind the selected entry to a new name. Overwriting `work_dir` with its own entry would turn\n",
    "# the dict into a string and make this cell fail with a TypeError when it is run a second time.\n",
    "output_dir = work_dir[model_type]\n",
    "train_data_text_file = f\"{data_dir}/{dataset}/train.tsv\"\n",
    "eval_data_text_file = f\"{data_dir}/{dataset}/dev.tsv\"\n",
    "fc_dropout = 0.1\n",
    "max_seq_length = 128\n",
    "batch_size = 32\n",
    "num_output_layers = 1\n",
    "\n",
    "# Hand the output directory to the factory so that `nf.work_dir`, which the evaluation\n",
    "# callback reads, is actually set. NOTE(review): `log_dir` is the NeMo 0.x keyword - confirm\n",
    "# against the installed version.\n",
    "nf = nemo.core.NeuralModuleFactory(\n",
    "    placement=nemo.core.DeviceType.GPU,\n",
    "    log_dir=output_dir,\n",
    ")\n",
    "\n",
    "# Pretrained encoder, built from the configuration file and the encoder checkpoint.\n",
    "model = nemo_nlp.nm.trainables.get_pretrained_lm_model(\n",
    "    config=model_config,\n",
    "    pretrained_model_name=pretrained_model_name[model_type],\n",
    "    checkpoint=model_checkpoint,\n",
    ")\n",
    "tokenizer = nemo.collections.nlp.data.tokenizers.get_tokenizer(\n",
    "    tokenizer_name='nemobert',\n",
    "    pretrained_model_name=pretrained_model_name[model_type],\n",
    "    do_lower_case=do_lower_case[model_type],\n",
    ")\n",
    "hidden_size = model.hidden_size\n",
    "\n",
    "# The data description supplies the number of labels for the classification head.\n",
    "data_desc = TextClassificationDataDesc(data_dir=f\"{data_dir}/{dataset}\", modes=['train', 'dev'])\n",
    "classifier = nemo_nlp.nm.trainables.SequenceClassifier(\n",
    "    hidden_size=hidden_size,\n",
    "    num_classes=data_desc.num_labels,\n",
    "    dropout=fc_dropout,\n",
    "    num_layers=num_output_layers,\n",
    "    log_softmax=False,\n",
    ")\n",
    "task_loss = CrossEntropyLossNM(weight=None)\n",
    "\n",
    "# Data layers: the training layer shuffles and uses the cache, the evaluation layer does neither.\n",
    "train_data_layer = BertTextClassificationDataLayer(\n",
    "    tokenizer=tokenizer,\n",
    "    input_file=train_data_text_file,\n",
    "    max_seq_length=max_seq_length,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    "    use_cache=True,\n",
    ")\n",
    "eval_data_layer = BertTextClassificationDataLayer(\n",
    "    tokenizer=tokenizer,\n",
    "    input_file=eval_data_text_file,\n",
    "    max_seq_length=max_seq_length,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    use_cache=False,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating Neural graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training graph: data layer -> encoder -> classifier -> loss.\n",
    "train_data = train_data_layer()\n",
    "train_hidden_states = model(\n",
    "    input_ids=train_data.input_ids,\n",
    "    token_type_ids=train_data.input_type_ids,\n",
    "    attention_mask=train_data.input_mask,\n",
    ")\n",
    "train_logits = classifier(hidden_states=train_hidden_states)\n",
    "loss = task_loss(logits=train_logits, labels=train_data.labels)\n",
    "\n",
    "# Steps per epoch for this setup. If you're training on multiple GPUs, this should be\n",
    "# len(train_data_layer) // (batch_size * batches_per_step * num_gpus)\n",
    "train_steps_per_epoch = len(train_data_layer) // batch_size\n",
    "logging.info(f\"doing {train_steps_per_epoch} steps per epoch\")\n",
    "\n",
    "# Evaluation graph: the same encoder and classifier applied to the evaluation data layer.\n",
    "eval_data = eval_data_layer()\n",
    "eval_hidden_states = model(\n",
    "    input_ids=eval_data.input_ids,\n",
    "    token_type_ids=eval_data.input_type_ids,\n",
    "    attention_mask=eval_data.input_mask,\n",
    ")\n",
    "eval_logits = classifier(hidden_states=eval_hidden_states)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create Callbacks\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def log_train_loss(tensors):\n",
    "    \"\"\"Log the first tensor's value as the training loss, with three decimals.\"\"\"\n",
    "    logging.info(\"Loss: {:.3f}\".format(tensors[0].item()))\n",
    "\n",
    "\n",
    "def train_tb_values(tensors):\n",
    "    \"\"\"Return the [name, value] pairs supplied to the callback as `get_tb_values`.\"\"\"\n",
    "    return [[\"loss\", tensors[0]]]\n",
    "\n",
    "\n",
    "def on_eval_iter(tensors, global_vars):\n",
    "    \"\"\"Forward one evaluation batch to `eval_iter_callback` together with the eval data layer.\"\"\"\n",
    "    return eval_iter_callback(tensors, global_vars, eval_data_layer)\n",
    "\n",
    "\n",
    "def on_eval_epochs_done(global_vars):\n",
    "    \"\"\"Call `eval_epochs_done_callback` with the `graphs` folder under `nf.work_dir`.\"\"\"\n",
    "    return eval_epochs_done_callback(global_vars, f'{nf.work_dir}/graphs')\n",
    "\n",
    "\n",
    "# Callback to log the training loss every 100 steps\n",
    "train_callback = nemo.core.SimpleLossLoggerCallback(\n",
    "    tensors=[loss],\n",
    "    print_func=log_train_loss,\n",
    "    get_tb_values=train_tb_values,\n",
    "    step_freq=100,\n",
    "    tb_writer=nf.tb_writer,\n",
    ")\n",
    "\n",
    "# Callback to evaluate the model every 500 steps\n",
    "eval_callback = nemo.core.EvaluatorCallback(\n",
    "    eval_tensors=[eval_logits, eval_data.labels],\n",
    "    user_iter_callback=on_eval_iter,\n",
    "    user_epochs_done_callback=on_eval_epochs_done,\n",
    "    tb_writer=nf.tb_writer,\n",
    "    eval_step=500,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training\n",
    "Training could take several minutes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimization hyperparameters.\n",
    "num_epochs = 3\n",
    "lr_warmup_proportion = 0.1\n",
    "lr = 3e-5\n",
    "weight_decay = 0.01\n",
    "\n",
    "# Learning-rate policy covering every step of the run.\n",
    "total_train_steps = num_epochs * train_steps_per_epoch\n",
    "lr_policy_fn = get_lr_policy(\n",
    "    \"WarmupAnnealing\",\n",
    "    total_steps=total_train_steps,\n",
    "    warmup_ratio=lr_warmup_proportion,\n",
    ")\n",
    "\n",
    "optimization_params = {\n",
    "    \"num_epochs\": num_epochs,\n",
    "    \"lr\": lr,\n",
    "    \"weight_decay\": weight_decay,\n",
    "}\n",
    "\n",
    "# Optimize the loss; both callbacks run during training.\n",
    "nf.train(\n",
    "    tensors_to_optimize=[loss],\n",
    "    callbacks=[train_callback, eval_callback],\n",
    "    lr_policy=lr_policy_fn,\n",
    "    optimizer=\"adam_w\",\n",
    "    optimization_params=optimization_params,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The result should look something like this:\n",
    "```\n",
    "                  precision    recall  f1-score   support\n",
    "    \n",
    "               0     0.7328    0.8348    0.7805       115\n",
    "               1     0.9402    0.9291    0.9346      7950\n",
    "               2     0.8311    0.9146    0.8708       199\n",
    "               3     0.6400    0.6302    0.6351       457\n",
    "               4     0.8002    0.8317    0.8156      1093\n",
    "               5     0.7228    0.7518    0.7370       548\n",
    "    \n",
    "        accuracy                         0.8949     10362\n",
    "       macro avg     0.7778    0.8153    0.7956     10362\n",
    "    weighted avg     0.8963    0.8949    0.8954     10362\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}